import numpy as np
import os
import sys
from plda import PldaStats,PldaEstimation
import argparse

def ComputePLDA(ivectors_reader, mean_out, plda_out):
    """Train a PLDA model from the vectors stored in an ``.npz`` archive.

    The archive is read once, the global mean of all vectors is subtracted
    and written to ``mean_out``, the centred vectors are grouped by speaker
    label and accumulated into ``PldaStats``, and the model estimated by
    ``PldaEstimation`` is written to ``plda_out + '.ori'`` and (after
    ``get_output()``) to ``plda_out``.

    Args:
        ivectors_reader: Anything ``np.load`` accepts that yields an archive
            with the arrays ``'vectors'`` (indexed here as 2-D, one row per
            vector) and ``'spker_label'`` (one label per row).
        mean_out: Path of the text file that receives the global mean,
            written as `` [ v0 v1 ... ]``.
        plda_out: Output path of the transformed PLDA model; the
            untransformed parameters go to ``plda_out + '.ori'``.

    Raises:
        ValueError: If the number of speaker labels differs from the number
            of vectors.
    """
    print('Load vecs and accumulate the stats of vecs.....')
    # Open the archive a single time and close it as soon as the arrays are
    # in memory. The archive's 'utt_label' array is not needed for training,
    # so it is not read.
    with np.load(ivectors_reader) as archive:
        vector_data = archive['vectors']
        spker_label = archive['spker_label']

    # Misaligned labels would silently train on wrong speaker groupings.
    if len(spker_label) != len(vector_data):
        raise ValueError(
            'Number of speaker labels (%d) does not match number of '
            'vectors (%d)' % (len(spker_label), len(vector_data)))

    print("Substract global mean")
    global_mean = np.mean(vector_data, axis=0)
    vector_data = vector_data - global_mean

    with open(mean_out, 'w') as f:
        f.write(' [ ' + ' '.join(map(str, global_mean)) + ' ]\n')

    dim = np.shape(vector_data)[1]

    # Collect the row indices of every speaker in one pass; the dict keeps
    # speakers in order of first appearance. Slicing the matrix once per
    # speaker afterwards avoids re-copying it with vstack for every row.
    rows_by_speaker = {}
    for row, spk in enumerate(spker_label):
        rows_by_speaker.setdefault(spk, []).append(row)

    plda_stats = PldaStats(dim)
    for rows in rows_by_speaker.values():
        # Indexing with a list of rows yields a fresh (len(rows), dim) array.
        vectors = np.asarray(vector_data[rows], dtype=float)
        # Every speaker is accumulated with unit weight.
        plda_stats.add_samples(1.0, vectors)

    print('Estimate the parameters of PLDA by EM algorithm...')
    plda_stats.sort()
    plda_estimator = PldaEstimation(plda_stats)
    plda_estimator.estimate()
    print('Save the parameters for the PLDA adaptation...')
    plda_estimator.plda_write(plda_out + '.ori')
    plda_trans = plda_estimator.get_output()
    print('Save the parameters for scoring directly, which is the same with the plda in kaldi...')
    plda_trans.plda_trans_write(plda_out)


if __name__ == "__main__":
    # Command-line entry point: parse the input/output paths, make sure the
    # PLDA output directory exists, then train and write the model.
    parser = argparse.ArgumentParser()

    parser.add_argument(
        '--train-npz', default='data/xvector.npz', help='npz file of training vector')
    parser.add_argument(
        '--global-mean', default='mean.vec', help='file of global mean vector in Kaldi')
    parser.add_argument(
        '--plda', default='plda', help='file of plda model in Kaldi')

    args = parser.parse_args()

    # os.path.dirname() returns '' when --plda has no directory component
    # (as with the default 'plda'), and os.makedirs('') raises
    # FileNotFoundError, so only create a directory when one is named.
    # exist_ok=True avoids the race between checking and creating.
    plda_dir = os.path.dirname(args.plda)
    if plda_dir:
        os.makedirs(plda_dir, exist_ok=True)

    ComputePLDA(args.train_npz, args.global_mean, args.plda)

